import sys
from contextlib import redirect_stdout

import numpy as np
import pandas as pd
from scipy.optimize import linprog

# prepare bin-level data for linear program
def prep_bin_level(blc):
    """Add the derived columns used by the linear program to ``blc``.

    ``blc`` must provide the columns ``share``, ``p_black`` and ``audited``.
    The frame is modified in place and the same object is returned.

    Columns added, in this order:
      * ``audit_rate``          -- copy of ``audited``
      * ``p_b_given_B``         -- ``share * p_black`` divided by its total
      * ``p_b_given_NB``        -- ``share * (1 - p_black)`` divided by
                                   one minus that total
      * ``p_nonblack``          -- ``1 - p_black``
      * ``pred_share_black``    -- same values as ``p_b_given_B``
      * ``pred_share_nonblack`` -- same values as ``p_b_given_NB``
      * ``lower_bound``         -- ``(audit_rate - p_nonblack) / p_black``,
                                   floored at 0
      * ``upper_bound``         -- ``audit_rate / p_black``, capped at 1
    """
    share = blc['share']
    prob_black = blc['p_black']
    prob_nonblack = 1 - prob_black
    rate = blc['audited']

    # Total share-weighted black probability and its complement.
    black_total = (share * prob_black).sum()
    nonblack_total = 1 - black_total

    black_given = share * prob_black / black_total
    nonblack_given = share * prob_nonblack / nonblack_total

    # Insertion order of this dict fixes the order of the new columns.
    derived = {
        'audit_rate': rate,
        'p_b_given_B': black_given,
        'p_b_given_NB': nonblack_given,
        'p_nonblack': prob_nonblack,
        'pred_share_black': black_given,
        'pred_share_nonblack': nonblack_given,
        'lower_bound': ((rate - prob_nonblack) / prob_black).clip(lower=0),
        'upper_bound': (rate / prob_black).clip(upper=1),
    }
    for column, values in derived.items():
        blc[column] = values
    return blc

# runs linear program
def run_program(blc,max_disp=True):
    """Solve the bin-level linear program and return the SciPy result.

    The decision variables are one value per row of ``blc`` (the per-bin
    black audit rate ``Y_i^B``), each boxed by that row's ``lower_bound`` /
    ``upper_bound``.  The objective is ``sum(share * p_black * x)``.

    Parameters
    ----------
    blc : pandas.DataFrame
        Must contain ``share``, ``p_black``, ``audited``, ``p_b_given_B``,
        ``p_b_given_NB``, ``lower_bound`` and ``upper_bound``.  Any index is
        accepted; rows are used positionally.
    max_disp : bool, default True
        If true the objective is maximized, otherwise it is minimized.

    Returns
    -------
    scipy.optimize.OptimizeResult
        The object returned by ``linprog``.  ``res.fun`` is the value of the
        objective that was handed to the solver, i.e. it is the *negated*
        optimum when ``max_disp`` is true.

    The optimum (with its natural sign) and the solution vector are printed
    on success; the solver message is printed on failure.
    """
    # Constants, taken as plain arrays so nothing depends on the frame's index
    p = blc['share'].to_numpy(dtype=float)
    b = blc['p_black'].to_numpy(dtype=float)
    Y = blc['audited'].to_numpy(dtype=float)
    b_bar = (b * blc['p_b_given_B'].to_numpy(dtype=float)).sum()
    b_bar_NB = (b * blc['p_b_given_NB'].to_numpy(dtype=float)).sum()

    # linprog always minimizes, so the coefficients are negated to maximize
    weights = p * b
    c = -weights if max_disp else weights

    # Per-variable bounds for Y_i^B, paired row by row (positional, not by label)
    bounds = list(zip(blc['lower_bound'].to_numpy(dtype=float),
                      blc['upper_bound'].to_numpy(dtype=float)))

    # Inequality constraints (A_ub @ x <= b_ub)
    #   first:  sum(p*b*x) >= sum(p*b*Y), written with both sides negated
    #   second: (b_bar - b_bar_NB) * sum(p*b*x) <= sum(p*Y*(b - b_bar_NB))
    A_ub_first = -1*p*b
    A_ub_second = -1*p*b*(b_bar_NB-b_bar)
    # NOTE(review): the original author left a "# 1?" question on this
    # right-hand side -- confirm the intended sign/coefficient.
    b_ub_first = -1*np.sum(p*b*Y)
    b_ub_second = 1*np.sum(p*Y*(b-b_bar_NB))
    # Combine the constraints
    A_ub = np.stack((A_ub_first, A_ub_second))
    b_ub = np.hstack((b_ub_first, b_ub_second))

    res = linprog(c, A_ub=A_ub, b_ub=b_ub, bounds=bounds, method='highs')
    # check if optimization was successful
    if res.success:
        # Undo the negation only when c was actually negated (maximization)
        optimum = -res.fun if max_disp else res.fun
        print('Optimal value:', optimum)
        print('Optimal solution:', res.x)
    else:
        print('Optimization failed:', res.message)
    return res

def print_out_results(blc):
    """Print the group audit rates implied by a solved program.

    Reads the columns ``audits``, ``audits_black``, ``audits_nonblack``,
    ``people_black`` and ``people_nonblack`` from ``blc`` and prints, via
    ``print``: the black and non-black audit rates (in percent, 3
    decimals), their difference (percentage points, 3 decimals) and their
    ratio (1 decimal).  Nothing is returned.
    """
    def as_percent(rate):
        return round(100*rate, 3)

    black_people = blc['people_black']
    nonblack_people = blc['people_nonblack']

    # Overall audit rate: computed but not part of the printed report.
    overall_rate = blc['audits'].sum()/(black_people+nonblack_people).sum()
    black_rate = blc['audits_black'].sum()/black_people.sum()
    nonblack_rate = blc['audits_nonblack'].sum()/nonblack_people.sum()

    print('Black Audit Rate: ', as_percent(black_rate))
    print('Non-Black Audit Rate: ', as_percent(nonblack_rate))
    print('Disparity: ', as_percent(black_rate-nonblack_rate))
    print('Audit Rate Ratio: ', round(black_rate / nonblack_rate, 1))

def update_binlevel_with_program(blc,res):
    """Return a copy of ``blc`` extended with the program's audit allocation.

    Parameters
    ----------
    blc : pandas.DataFrame
        Must contain ``share``, ``p_black`` and ``audit_rate``.  It is not
        modified, so results from several programs can be kept side by side.
    res : object with an ``x`` attribute
        Solver result; ``res.x`` holds one black audit rate per row of
        ``blc`` (positional).

    Returns
    -------
    pandas.DataFrame
        Copy of ``blc`` with the added columns ``program_audit_rate_black``,
        ``program_audit_rate_nonblack``, ``people_black``,
        ``people_nonblack``, ``audits_black``, ``audits_nonblack`` and
        ``audits``.

    Raises
    ------
    ValueError
        If ``res.x`` is ``None`` (the solver produced no solution).
    """
    if res.x is None:
        raise ValueError('linear program has no solution to apply (res.x is None)')

    # Work on a copy: the caller reuses the same input frame for the
    # maximum- and minimum-disparity programs.
    out = blc.copy()
    p_black = out['p_black']

    out['program_audit_rate_black'] = res.x
    # Non-black rate is whatever keeps each bin's overall audit rate fixed:
    # audit_rate = rate_black * p_black + rate_nonblack * (1 - p_black)
    out['program_audit_rate_nonblack'] = (
        (out['audit_rate'] - out['program_audit_rate_black']*p_black) / (1 - p_black)
    )
    out['people_black'] = p_black*out['share']
    out['people_nonblack'] = (1 - p_black)*out['share']
    out['audits_black'] = out['people_black']*out['program_audit_rate_black']
    out['audits_nonblack'] = out['people_nonblack']*out['program_audit_rate_nonblack']
    out['audits'] = out['audits_black'] + out['audits_nonblack']
    return out

# read in stats from different bins based on BIFSG-predicted probability black
# from sharpness_table.py
bin_level = pd.read_csv('/REDACTED/data/final/sharpness_table_20.csv')
bin_level.columns = ['share', 'audited', 'p_black', 'bin']

# prep_bin_level adds its derived columns to the frame and returns it
blc = bin_level
blc = prep_bin_level(blc)

# solve for the largest and the smallest value of the objective
cres_max = run_program(blc, True)
cres_min = run_program(blc, False)

# write out results
# The file is closed (and flushed) when the block exits, and stdout is put
# back to whatever it was before -- even if one of the calls below raises.
with open("/REDACTED/linear_program_results.txt", "w") as results_file:
    with redirect_stdout(results_file):
        print("Maximum Disparity")
        # each program gets its own copy so blc_max and blc_min stay distinct
        blc_max = update_binlevel_with_program(blc.copy(), cres_max)
        print_out_results(blc_max)

        print("\nMinimum Disparity")
        blc_min = update_binlevel_with_program(blc.copy(), cres_min)
        print_out_results(blc_min)